# -*- coding: utf-8 -*-
# 导入标准库
import json
import os
# 导入 PyTorch
import torch
import sys
from pathlib import Path
bert_query_classifier_PATH ='./model/bert_query_classifier'
BERT_PATH = './model/bert-base-chinese'
# 导入日志
import logger
# 导入numpy
import numpy as np
# 导入 Transformers 库
from transformers import BertTokenizer, BertForSequenceClassification
from transformers import Trainer, TrainingArguments
# 导入train_test_split
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
# 导入魔搭社区库 - 关键修改点
# from modelscope.trainers import Trainer
# from modelscope.trainers.training_args import TrainingArguments
# from swift import Trainer, TrainingArguments
from modelscope import AutoModelForSequenceClassification,AutoModel, AutoTokenizer
import bert_dawnload
# Rebind the name `logger` from the project `logger` module to the logger
# object that module exposes as its `logger` attribute.
logger = logger.logger
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ModelScope model ID; currently only referenced from commented-out code.
BERT_MODEL_ID = 'tiansz/bert-base-chinese'
# Local directory of the pre-trained base model (rebinds the BERT_PATH set
# among the imports above; this later value is the one in effect).
BERT_PATH = './model/bert/bert-base-chinese'
# Directory the fine-tuned classifier is saved to / loaded from.
CLASSIFIER_PATH = './model/bert_query_classifier'
# Same directory under the older name; QueryClassifier's default model_path reads it.
bert_query_classifier_PATH = CLASSIFIER_PATH
class QueryClassifier:
    """BERT sequence classifier that labels a query as "通用知识" or "专业咨询".

    The tokenizer is always loaded from the local base model at ``BERT_PATH``.
    The classification model is loaded from ``model_path`` when a saved
    checkpoint is found there; otherwise a new classification head with one
    output per entry of ``label_map`` is initialised on top of ``BERT_PATH``.
    """

    # Maximum number of tokens per input; longer texts are truncated.
    MAX_LENGTH = 128

    def __init__(self, model_path=str(bert_query_classifier_PATH)):
        """Create the classifier and load (or initialise) its model.

        Args:
            model_path: Directory of a fine-tuned classifier. It is read by
                ``load_model`` and written by ``save_model``.

        Raises:
            FileNotFoundError: If the local base model directory ``BERT_PATH``
                does not exist.
        """
        self.model_path = model_path
        if not os.path.exists(BERT_PATH):
            raise FileNotFoundError(f"本地模型路径不存在: {BERT_PATH}")

        # Tokenizer comes from the local base-model directory.
        self.tokenizer = BertTokenizer.from_pretrained(str(BERT_PATH))
        self.model = None
        self.device = device
        logger.info(f"使用设备: {self.device}")
        # Text label -> class index (the index is the position in the model's logits).
        self.label_map = {"通用知识": 0, "专业咨询": 1}
        self.load_model()

    def load_model(self):
        """Load the saved classifier, or initialise a new one from the base model.

        ``self.model_path`` is only treated as a checkpoint when it contains
        ``config.json`` (written by ``save_pretrained``); a missing or
        incomplete directory falls back to a fresh head on ``BERT_PATH``.
        The resulting model is moved to ``self.device``.
        """
        has_checkpoint = os.path.isfile(os.path.join(str(self.model_path), "config.json"))
        if has_checkpoint:
            self.model = BertForSequenceClassification.from_pretrained(self.model_path)
        else:
            self.model = BertForSequenceClassification.from_pretrained(
                str(BERT_PATH), num_labels=len(self.label_map)
            )
        self.model.to(self.device)
        if has_checkpoint:
            logger.info(f"加载模型: {self.model_path}")
        else:
            logger.info("初始化新 BERT 模型")

    def save_model(self):
        """Save the model and tokenizer to ``self.model_path``."""
        self.model.save_pretrained(self.model_path)
        self.tokenizer.save_pretrained(self.model_path)
        logger.info(f"模型保存至: {self.model_path}")

    def _encode(self, texts):
        """Tokenize ``texts`` (a string or list of strings) into padded PyTorch tensors."""
        return self.tokenizer(
            texts,
            truncation=True,
            padding=True,
            max_length=self.MAX_LENGTH,
            return_tensors="pt",
        )

    def preprocess_data(self, texts, labels):
        """Convert raw texts and text labels into BERT inputs and class indices.

        Returns:
            ``(encodings, label_ids)``: the tokenizer output (tensors) and a
            list of ints looked up in ``self.label_map``.

        Raises:
            KeyError: If a label is not a key of ``self.label_map``.
        """
        encodings = self._encode(texts)
        return encodings, [self.label_map[label] for label in labels]

    def create_dataset(self, encodings, labels):
        """Wrap encodings and integer labels in a ``torch.utils.data.Dataset``.

        Each item is a dict mapping every encoding key to that example's
        tensor, plus ``"labels"`` holding the label as a tensor.
        """

        class Dataset(torch.utils.data.Dataset):
            def __init__(self, encodings, labels):
                self.encodings = encodings
                self.labels = labels

            def __getitem__(self, idx):
                item = {key: val[idx] for key, val in self.encodings.items()}
                item["labels"] = torch.tensor(self.labels[idx])
                return item

            def __len__(self):
                return len(self.labels)

        return Dataset(encodings, labels)

    def train_model(self, data_file="training_dataset_hybrid_5000.json"):
        """Fine-tune the classifier on a JSON Lines file, save it and log metrics.

        Each non-blank line of ``data_file`` must be a JSON object with a
        ``"query"`` text and a ``"label"`` that is a key of ``self.label_map``.
        The rows are split 80/20 into train/validation (``random_state=42``).

        Raises:
            FileNotFoundError: If ``data_file`` does not exist.
            ValueError: If ``data_file`` contains no records.
        """
        if not os.path.exists(data_file):
            logger.error(f"数据集文件 {data_file} 不存在")
            raise FileNotFoundError(f"数据集文件 {data_file} 不存在")

        with open(data_file, "r", encoding="utf-8") as f:
            # Skip blank lines (e.g. an extra newline at EOF): json.loads rejects them.
            data = [json.loads(line) for line in f if line.strip()]
        if not data:
            logger.error(f"数据集文件 {data_file} 为空")
            raise ValueError(f"数据集文件 {data_file} 为空")

        texts = [item["query"] for item in data]
        labels = [item["label"] for item in data]

        train_texts, val_texts, train_labels, val_labels = train_test_split(
            texts, labels, test_size=0.2, random_state=42
        )

        # Tokenize; the labels come back as integer class ids.
        train_encodings, train_labels = self.preprocess_data(train_texts, train_labels)
        val_encodings, val_labels = self.preprocess_data(val_texts, val_labels)

        train_dataset = self.create_dataset(train_encodings, train_labels)
        val_dataset = self.create_dataset(val_encodings, val_labels)

        training_args = TrainingArguments(
            output_dir="./bert_results",
            num_train_epochs=3,
            per_device_train_batch_size=8,
            per_device_eval_batch_size=8,
            warmup_steps=500,
            weight_decay=0.01,
            logging_dir="./bert_logs",
            logging_steps=10,
            # `eval_strategy` replaces the deprecated `evaluation_strategy`.
            eval_strategy="epoch",
            save_strategy="epoch",
            load_best_model_at_end=True,
            # Limit saved checkpoints; with load_best_model_at_end the best one is kept.
            save_total_limit=1,
            metric_for_best_model="eval_loss",
            fp16=False,  # mixed precision disabled
        )

        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            compute_metrics=self.compute_metrics,
        )

        logger.info("开始训练 BERT 模型...")
        trainer.train()
        self.save_model()

        # val_labels are integer ids at this point, as evaluate_model expects.
        self.evaluate_model(val_texts, val_labels)

    def compute_metrics(self, eval_pred):
        """Return ``{"accuracy": float}`` from a ``(logits, labels)`` pair."""
        logits, labels = eval_pred
        predictions = np.argmax(logits, axis=-1)
        return {"accuracy": float((predictions == labels).mean())}

    def evaluate_model(self, texts, labels):
        """Log a classification report and confusion matrix for ``texts``.

        Args:
            texts: Raw query strings.
            labels: Integer class ids (already mapped through ``self.label_map``).
        """
        dataset = self.create_dataset(self._encode(texts), labels)

        trainer = Trainer(model=self.model)
        predictions = trainer.predict(dataset)
        pred_labels = np.argmax(predictions.predictions, axis=-1)

        # Pass every class id explicitly so the report and matrix stay
        # well-formed even if a class is absent from this split (otherwise
        # target_names would not match the labels found, and sklearn raises).
        id_to_name = {idx: name for name, idx in self.label_map.items()}
        class_ids = sorted(id_to_name)
        target_names = [id_to_name[idx] for idx in class_ids]

        logger.info("分类报告:")
        logger.info(classification_report(
            labels,
            pred_labels,
            labels=class_ids,
            target_names=target_names,
            zero_division=0,
        ))
        logger.info("混淆矩阵:")
        logger.info(confusion_matrix(labels, pred_labels, labels=class_ids))

    def predict_category(self, query):
        """Return the predicted category name for a single query string.

        Logs an error and returns "通用知识" when no model is loaded; an
        unexpected class index also maps to "通用知识".
        """
        if self.model is None:
            logger.error("模型未训练或加载")
            return "通用知识"

        encoding = self._encode(query)
        encoding = {key: tensor.to(self.device) for key, tensor in encoding.items()}
        # Inference mode: disable dropout so predictions are deterministic,
        # regardless of the mode training left the model in.
        self.model.eval()
        with torch.no_grad():
            logits = self.model(**encoding).logits
        prediction = torch.argmax(logits, dim=1).item()

        id_to_name = {idx: name for name, idx in self.label_map.items()}
        return id_to_name.get(prediction, "通用知识")


if __name__ == "__main__":
    # NOTE(review): bert_dawnload.load_bert() is opaque from here; presumably it
    # makes the base BERT model available at BERT_PATH, which QueryClassifier
    # requires to exist — confirm.
    bert_dawnload.load_bert()

    # NOTE(review): this directory differs from the module default
    # CLASSIFIER_PATH ('./model/bert_query_classifier'); confirm which one
    # consumers of the trained model expect.
    classifier_dir = "bert_query_classifier"
    training_file = 'model_generic_5000.json'

    clf = QueryClassifier(model_path=classifier_dir)
    clf.train_model(data_file=training_file)

    # Example predictions — uncomment to sanity-check the trained model:
    # for sample in ("AI学科的课程大纲是什么", "JAVA课程费用多少？",
    #                "5*9等于多少？", "AI培训有哪些老师？"):
    #     print(f"查询: {sample} -> 分类: {clf.predict_category(sample)}")
